
import os, sys


from bmlp.models.mlpmixer import MLPMixer, SMixer
from bmlp.models.masked_mlp import MaskedMLP

import torch
import numpy as np


sys.path.append( os.path.dirname(__file__) + "/../../" )
from tqdm import tqdm

import math
import gc
from torchvision.datasets import   CIFAR10, CIFAR100 #,STL10
from datasets.stl10 import STL10
from torchvision.transforms import transforms, AutoAugment, AutoAugmentPolicy
import torchvision
import torch.optim as optim

import argparse

import wandb
from bmlp.sc_balance import configure_connections






def adjust_learning_rate(optimizer, epoch, max_epoch, base_lr, cos=1, schedule=None):
    """Set the learning rate of every parameter group for the given epoch.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts whose
            ``'lr'`` entry is overwritten.
        epoch: Current epoch index (0-based, counted after any warmup).
        max_epoch: Total number of epochs the cosine curve spans. Must be
            non-zero when ``cos == 1``.
        base_lr: Learning rate before any decay is applied.
        cos: ``1`` selects cosine annealing; any other value selects the
            stepwise schedule.
        schedule: Iterable of milestone epochs for the stepwise schedule; the
            rate is multiplied by 0.1 for every milestone already reached.
            ``None`` (the default) means "no milestones".
    """
    lr = base_lr
    if cos == 1:
        # Half-cosine from base_lr (epoch 0) down to 0 (epoch == max_epoch).
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / max_epoch))
    else:
        # ``None`` replaces the former mutable ``[]`` default argument.
        for milestone in (schedule if schedule is not None else ()):
            if epoch >= milestone:
                lr *= 0.1
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup(optimizer, epoch, base_lr, warmup_epoch):
    """Linearly scale the learning rate while ``epoch < warmup_epoch``.

    During warmup every parameter group's ``'lr'`` is set to
    ``base_lr * epoch / warmup_epoch`` (so epoch 0 trains with a rate of 0).
    Once ``epoch`` reaches ``warmup_epoch`` the optimizer is left untouched.
    """
    if not (epoch < warmup_epoch):
        return
    ramped = base_lr*(epoch/warmup_epoch)
    for group in optimizer.param_groups:
        group['lr'] = ramped



def test(loader, net, criterion, device="cuda"):
    """Evaluate ``net`` on ``loader`` and print/return accuracy and loss.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches.
        net: Model to evaluate; it is moved to ``device``.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor.
        device: Device the model and every batch are moved to.

    Returns:
        Tuple ``(accuracy, loss)``: the fraction of correctly classified
        samples, and the mean of the per-batch losses (every batch counts
        equally, regardless of its size).

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    net.to(device)
    # Remember the caller's mode: evaluation must not silently leave the model
    # in eval mode, otherwise a surrounding training loop keeps training with
    # eval-mode layer behaviour after the first evaluation.
    was_training = net.training
    net.eval()
    total = 0.
    correct = 0.
    loss = 0.
    print("start test ...")
    count = 0
    try:
        with torch.no_grad():
            for x, l in tqdm(loader):
                x = x.to(device)
                l = l.to(device)
                y = net(x)
                total += l.shape[0]
                # The loss uses the labels exactly as the loader provided them.
                loss += criterion(y, l).item()
                _, predicted = torch.max(y.data, 1)
                # Labels shaped like the output (one score per class) are
                # reduced to class indices before comparing.
                if l.shape == y.shape:
                    _, l = torch.max(l.data, 1)
                correct += (predicted == l).sum().item()
                count += 1
    finally:
        net.train(was_training)

    accuracy = correct/total
    loss = loss/count
    print("test accuracy: {:.4f}".format(accuracy))
    print("test loss: {:.6f}".format(loss))
    return accuracy, loss


def _correct(y,l):
    """Return how many rows of ``y`` have their arg-max equal to the label.

    ``l`` may hold class indices, or have the same shape as ``y`` (one score
    per class), in which case its arg-max along dim 1 is used as the label.
    """
    predicted = torch.max(y.data, 1)[1]
    targets = torch.max(l.data, 1)[1] if l.shape == y.shape else l
    return (predicted == targets).sum().item()


# Network names whose token/channel dimensions are derived from the
# connection budget (see ``_resolve_dimensions``).
_MIXER_FAMILY = ("MLPMixer", "SMixer", "MLPMixerSymm", "MLPMixerPad", "MLPMixerFullPerm")


def _parse_args():
    """Build the command-line parser and return the parsed namespace."""
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--aug',  type=int, default=4,
                        help='augmentation type ( 0, 1,..., 4)(default: %(default)s)')
    parser.add_argument('--batch',  type=int, default=128,
                        help='batch size for training (default: %(default)s)')
    parser.add_argument('--cos', type=int, default=1,
                        help='cosine annealing')

    parser.add_argument("--dataset", type=str, default="CIFAR10",
                        help="dataset")

    parser.add_argument('--device_id',  type=int, default=0,
                        help='GPU id. Set -1 for cpu (default: %(default)s)')

    parser.add_argument('--dim', metavar='N', type=int, default=128, # 128 for bmlp
                        help=' input dim of channelMLP(default: %(default)s)')
    parser.add_argument('--dim_token', type=int, default=-1,
                        help=' input dim of tokenMLP(default: %(default)s)')

    parser.add_argument('--epoch',  type=int, default=200, #600 for bmlp
                        help='max epoch (default: %(default)s)')

    parser.add_argument("-j", '--job_name',  type=str, default="default",
                        help='job name to identify experiments (default: %(default)s)')

    parser.add_argument('--L', metavar='N', type=int, default=2,
                        help='Number of Layers of MLP (default: %(default)s)')

    parser.add_argument('--lr',  type=float, default=1e-1,
                        help='learning rate for training (default: %(default)s)')

    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum of SGD (default: %(default)s)')

    # Raw string: "\g" is not a valid escape sequence in a normal literal.
    parser.add_argument('--num_connections',  type=int, default=2**21,
                        help=r'  0.5*\gamma*( dim**2*dim_token+ dim*dim_token**2)(default: %(default)s)')

    parser.add_argument('--num_params',  type=int, default=-1,
                        help=' max: dim**2+ dim_token**2(default: %(default)s)')

    parser.add_argument('--num_workers',  type=int, default=8,
                        help=' max: 64(default: %(default)s)')

    parser.add_argument('--net',  type=str, default="MLPMixer", #BMLP
                        help='MLPMixer, SMixer, MaskedMLP (default: %(default)s)')

    parser.add_argument("-p", '--patch_size', type=int, default=4,
                        help=' patch size of inputs(default: %(default)s)')
    parser.add_argument('--permute',type=int, default=0,
                        help='permute indices per every block. 0: none, 2: every blocks (default: %(default)s)')

    parser.add_argument("--perm_block_id", type=int, default=-1,
                        help=" id of perm  (default: %(default)s)")

    parser.add_argument('--prod_dim',  type=int, default=-1,
                        help=' dim*dim_token. If prod_dim > 0 , we set dim_token = prod_dim/dim (default: %(default)s)')

    parser.add_argument('--max_dim',  type=int, default=-1,
                        help=' dim*dim_token*max(1, ef).(default: %(default)s)')

    parser.add_argument('--resize',  type=int, default=32,
                        help=' img_size to resize(default: %(default)s)')
    parser.add_argument('--schedule', type=int, default=-1 ,
                        help="schedule for lr decay  (default: %(default)s  )")

    parser.add_argument('--warmup_epoch',  type=int, default=0,
                        help='max epoch for warmup (default: %(default)s)')

    parser.add_argument('-ef', '--expansion_factor', type=float, default=-1, #0.5
                        help=' expansion_factor  for both MLPMixer and SMixer. If < 0,  each block is FC + Activation. (default: %(default)s)')
    parser.add_argument("--seed", type=int,default=42,
                        help="random seed (default: %(default)s)")
    parser.add_argument("--channels", type=int, default=3,
                        help=" c (default: %(default)s)")

    parser.add_argument("--fix", type=int,default=2,
                        help="0: fix dim*dim_token,  1: fix dim**2 + dim_token**2, 2: fix dim**2*dim_token + dim_token**2*dim (default: %(default)s)")

    parser.add_argument("--force_fc_in_skip", type=int,default=0,
                        help=" If it is not 0, use FC in skip-connection in first token-mlp even if S0 = S (default: %(default)s)")
    parser.add_argument("--opt", type=str, default="sgd",
                        help=" optimizer (sgd or adamw) (default: %(default)s)")

    # NOTE: a dedicated `-eft/--expansion_factor_token` option is disabled;
    # the token MLP reuses --expansion_factor (see `_build_network`).

    ### MaskedMLP
    parser.add_argument("--freezing_rate", type=float, default=0,
                        help=" freezing_rate (default: %(default)s)")

    return parser.parse_args()


def _resolve_dimensions(args):
    """Derive layer sizes from the connection budget; flag invalid setups.

    Mutates ``args`` in place and sets ``args.stop_criterion`` to 1 when the
    requested configuration cannot be run, 0 otherwise.
    """
    args.stop_criterion = 0

    if args.fix == 2:
        if args.net in _MIXER_FAMILY:
            # `configure_connections` (defined in bmlp.sc_balance) fills in the
            # dimensions. Formula kept from the original note -- TODO confirm
            # against that helper:
            #   s = (sqrt(c^2 + 8*Gamma/(c*gamma)) - c) / 2
            configure_connections(args)
            print("max_dim:", args.max_dim)
        elif args.net in ["MaskedMLP"]:
            if args.freezing_rate < 0 or args.freezing_rate > 1:
                print("freezing_rate must be in [0,1]")
                args.stop_criterion = 1
            elif args.freezing_rate == 1:
                # p = 0 would divide by zero below: no width can be derived
                # when every connection is frozen.
                print("freezing_rate == 1 leaves no connections; cannot derive width")
                args.stop_criterion = 1
            else:
                # Determine the width w from the connection budget Gamma and
                # the kept fraction p = 1 - freezing_rate:
                #   Gamma = ef * w^2 * p   =>   w = sqrt(Gamma / (ef * p))
                p = 1 - args.freezing_rate
                ef = args.expansion_factor if args.expansion_factor > 0 else 1
                squared = args.num_connections/(ef*p)
                args.prod_dim = round(math.sqrt(squared))

                args.max_dim = max(1, ef)*args.prod_dim

                print("width=", args.prod_dim)
                if args.prod_dim < 1:
                    args.stop_criterion = 1
        else:
            args.stop_criterion = 1

    elif args.fix == 3:
        # Fixed: num_connections, prod_dim, num_patches, dim_ppfc (= output of ppfc).
        if args.prod_dim < 1:
            args.stop_criterion = 1
        elif args.net in _MIXER_FAMILY:
            # With m = prod_dim, s = dim_token, c = dim:
            #   Gamma = gamma * m * (s + c) / 2   =>   gamma = 2*Gamma / ((s + c) * m)
            c = args.dim
            s = round(args.prod_dim/c)
            ### compute expanding factor from s and c
            ef = 2*args.num_connections/(args.prod_dim*(s+c))
            args.expansion_factor = ef
            args.dim_token = s
            args.max_dim = round(ef*s*c)
            args.num_params = (s**2 + c**2)*ef/2
            print("max_dim:", args.max_dim)
            print("ef:", args.expansion_factor)


def _build_transforms(args, is_stl10):
    """Return ``(train_transform, test_transform)`` for the chosen dataset."""
    if args.aug not in (4, 6):
        raise ValueError(f"unsupported --aug value {args.aug}; supported values are 4 and 6")

    if is_stl10:
        # The widely used ImageNet channel statistics.
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    else:
        ### cifar
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))

    train_steps = [
        transforms.RandomCrop(args.resize, padding=4, fill=128),
        transforms.RandomHorizontalFlip(),
    ]
    if args.aug == 6:
        train_steps.append(AutoAugment(AutoAugmentPolicy.CIFAR10, fill=128))
    train_steps += [transforms.ToTensor(), normalize]

    test_transform = transforms.Compose([
        transforms.Resize(args.resize),
        transforms.ToTensor(),
        normalize,
    ])
    return transforms.Compose(train_steps), test_transform


def _build_network(args, num_classes, device):
    """Instantiate the model selected by ``--net``."""
    if args.net == "MLPMixer":
        return MLPMixer(
            image_size=args.resize,
            channels=args.channels,
            patch_size=args.patch_size,
            dim=args.dim,
            dim_token=args.dim_token,
            depth=args.L,
            num_classes=num_classes,
            permute_per_blocks=args.permute,
            expansion_factor=args.expansion_factor,
            # The token MLP shares the channel MLP's expansion factor.
            expansion_factor_token=args.expansion_factor,
            force_fc_in_skip=args.force_fc_in_skip,
            perm_block_id=args.perm_block_id,
        )
    if args.net == "MaskedMLP":
        return MaskedMLP(
            image_size=args.resize,
            channels=3,
            width=args.prod_dim,
            depth=args.L,
            num_classes=num_classes,
            expansion_factor=args.expansion_factor,
            freezing_rate=args.freezing_rate,
            patch_size=args.patch_size,
            use_skip=True,
            mask_device=device,
        )
    raise ValueError(f"unsupported --net {args.net!r}; supported: MLPMixer, MaskedMLP")


def _build_optimizer(args, net):
    """Instantiate the optimizer selected by ``--opt``."""
    if args.opt == "sgd":
        return optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, nesterov=True)
    if args.opt == "adamw":
        return optim.AdamW(net.parameters(), lr=args.lr)
    raise ValueError(f"unsupported --opt {args.opt!r}; supported: sgd, adamw")


def main():
    """Train and evaluate the selected network, logging each epoch to wandb.

    Steps: parse the CLI, derive layer sizes, start the wandb run, build the
    data pipeline / model / optimizer, then for every epoch set the learning
    rate, train for one pass, evaluate on the test split and log the metrics.

    Raises:
        ValueError: For an unsupported ``--dataset``, ``--aug``, ``--net`` or
            ``--opt`` value.
    """
    args = _parse_args()
    _resolve_dimensions(args)

    # NOTE: the wandb run is created even when the configuration is rejected
    # just below.
    wandb.init(project=f"{args.job_name}", config=args)

    if args.stop_criterion == 1:
        print("(main)force return")
        return

    if args.device_id >= 0:
        assert torch.cuda.is_available(), "CUDA requested (device_id >= 0) but not available"
        device = "cuda:{}".format(args.device_id)
    else:
        device = "cpu"
    args.device = device

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    criterion = torch.nn.CrossEntropyLoss()

    # Milestones only matter for the stepwise schedule (cos != 1).
    schedule = [args.schedule] if (args.cos != 1 and args.schedule > 0) else []

    # Dataset names are matched case-insensitively everywhere; previously the
    # CIFAR names were case-sensitive here but lower-cased further down.
    dataset_table = {
        "cifar10": (CIFAR10, 10),
        "cifar100": (CIFAR100, 100),
        "stl10": (STL10, 10),
    }
    dataset_key = args.dataset.lower()
    if dataset_key not in dataset_table:
        raise ValueError(f"unsupported --dataset {args.dataset!r}; supported: CIFAR10, CIFAR100, STL10")
    dataset_cls, num_classes = dataset_table[dataset_key]

    transform, test_transform = _build_transforms(args, is_stl10=(dataset_key == "stl10"))

    net = _build_network(args, num_classes, device)
    optimizer = _build_optimizer(args, net)

    assert len(optimizer.param_groups) == 1 ### Debug
    # os.cpu_count() may return None when the count cannot be determined.
    num_workers = min(os.cpu_count() or 1, args.num_workers)

    target_transform = None
    torch.backends.cudnn.benchmark = False ### if true then fast, but we lose reproducibility
    if dataset_key in ("cifar10", "cifar100"):
        train_set = dataset_cls(root='./data',
                                train=True,
                                download=True,
                                transform=transform,
                                target_transform=target_transform)
        test_set = dataset_cls(root='./data',
                               train=False,
                               download=True,
                               transform=test_transform,
                               target_transform=target_transform)
    else:
        train_set = dataset_cls(root='./data',
                                split="train",
                                download=False,
                                transform=transform,
                                target_transform=target_transform)
        test_set = dataset_cls(root='./data',
                               split="test",
                               download=False,
                               transform=test_transform,
                               target_transform=target_transform)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)

    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=500,
                                              shuffle=False,
                                              num_workers=4)

    net.to(device)

    # The set of modules does not change during training, so collect the
    # masked layers once instead of rescanning the model on every batch.
    masked_layers = []
    if args.net == "MaskedMLP":
        masked_layers = [m for m in net.modules() if m._get_name() == "MaskedLinear"]

    best_test_acc = 0
    best_test_epoch = 0
    best_test_train_acc = 0

    for epoch in tqdm(range(args.epoch+args.warmup_epoch)):
        if epoch < args.warmup_epoch:
            warmup(optimizer, epoch, args.lr, args.warmup_epoch)
        else:
            adjust_learning_rate(optimizer, epoch-args.warmup_epoch, args.epoch, args.lr, args.cos, schedule)

        # Evaluation puts the model in eval mode; switch back every epoch so
        # training never runs with eval-mode layer behaviour.
        net.train()

        mean_loss = 0
        total_batch = 0
        total_correct = 0
        for x, l in tqdm(train_loader):
            total_batch += x.shape[0]
            x = x.to(device, non_blocking=True)
            l = l.to(device, non_blocking=True)

            optimizer.zero_grad()
            y = net(x)
            loss = criterion(y, l)
            loss.backward()

            total_correct += _correct(y, l)

            # `mask_grad` is defined on MaskedLinear elsewhere; presumably it
            # masks the gradients of frozen weights -- TODO confirm.
            for layer in masked_layers:
                layer.mask_grad()

            optimizer.step()

            loss_value = loss.item()
            if np.isnan(loss_value) or loss_value > 1e+8:
                # Diverged: abandon the rest of this epoch. Evaluation and
                # logging below still run, and later epochs are still attempted.
                break
            mean_loss += loss_value*x.shape[0]

        test_accuracy, test_loss = test(test_loader, net, criterion, device=device)

        train_accuracy = total_correct/total_batch
        mean_loss /= total_batch

        # `<=` keeps the latest epoch among ties.
        if best_test_acc <= test_accuracy:
            best_test_acc = test_accuracy
            best_test_epoch = epoch
            best_test_train_acc = train_accuracy

        wandb.log({
            "Loss/Test": test_loss,
            "Loss/Train": mean_loss,
            "Loss/Gap":  mean_loss - test_loss,
            "Accuracy/Test": test_accuracy,
            "Accuracy/Train": train_accuracy,
            "Accuracy/Gap": test_accuracy - train_accuracy,
            "Best/Test/Accuracy": best_test_acc,
            "Best/Test/Train": best_test_train_acc,
            "Best/Test/Epoch": best_test_epoch
        })





# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
